# ------------------------------------------------------------------------------
# Plots performance of limited bandwidth models
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 04_plot_performance.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(here)
library(yaml)
library(tidyverse)
library(glue)
library(ggthemes)
library(data.table)

# Working directory for this analysis step: the performance .rds inputs are
# read from here and the finished plots are written back into it.
temp <- here("code", "06_physician_boundedness", "02_behavioral_gbms", "temp")

# Project-local modules. `a` provides the colour palette and font helper used
# by the plots; `u` is loaded alongside it but is not referenced in the visible
# part of this script.
a <- modules::use(here("lib", "aesthetics.R"))
u <- modules::use(here("lib", "util.R"))

# Presumably registers the bundled font file under the family name "Optima",
# which the plot theme and annotations request -- confirm in lib/aesthetics.R.
a$get_font("Optima", here("lib", "optima.ttf"))

# Helpers ----------------------------------------------------------------------
# Thin a set of positive values into axis breaks that are at least `step_size`
# apart on the log10 scale.
#
# Args:
#   x: numeric vector of values (duplicates and ordering do not matter).
#   step_size: minimum gap, in log10 units, between two consecutive breaks.
#
# Returns:
#   Increasing numeric vector of breaks on the original scale. The smallest and
#   the largest value of `x` are always included, each exactly once; an empty
#   `x` gives an empty vector.
log_step <- function(x, step_size = 0.25) {
  log_x <- sort(unique(log10(x)))
  n <- length(log_x)
  if (n == 0) {
    return(numeric(0))
  }

  # Greedy pass: keep a value only if it is far enough from the last kept one.
  keep <- logical(n)
  keep[1] <- TRUE
  last_kept <- log_x[1]
  for (ii in seq_len(n)[-1]) {
    if (log_x[ii] - last_kept >= step_size) {
      keep[ii] <- TRUE
      last_kept <- log_x[ii]
    }
  }

  # The axis should always end on the largest value. Flagging it in the mask
  # (instead of appending it) avoids a duplicated break when it was already
  # kept by the pass above.
  keep[n] <- TRUE

  10^(log_x[keep])
}

# Build a compact identifier of the form "<measure>(<target>)".
#
# The "_010_day" suffix is stripped from the target name first, so that
# e.g. ("auc", "test_010_day") becomes "auc(test)".
name_performance <- function(measure_name, target_name) {
  short_target <- str_remove(target_name, "_010_day")
  glue::glue("{measure_name}({short_target})")
}
# Human-readable legend label for a prediction target.
#
# Args:
#   measure_name: unused; kept so the function can be mapped over
#     (measure_name, target_name) pairs alongside name_performance().
#   target_name: character vector of target names.
#
# Returns:
#   "Yield of Testing" where the target name contains "stent", otherwise
#   "Testing Decision" (same length as `target_name`).
label_performance <- function(measure_name, target_name) {
  # Return the value directly: ending on an assignment to a throwaway local
  # would hand the result back only invisibly.
  ifelse(
    grepl("stent", target_name), "Yield of Testing", "Testing Decision"
  )
}

# Output path for the plot of one measure:
#   <dir_name>/BEHAVIORAL_GBM_<MEASURE>[_zoomed].png
# The measure name is upper-cased; "_zoomed" is appended when `zoom` is TRUE.
name_plot <- function(measure_name, dir_name, zoom = FALSE) {
  suffix <- ifelse(zoom, "_zoomed", "")
  measure_tag <- toupper(measure_name)
  base_name <- glue("BEHAVIORAL_GBM_{measure_tag}{suffix}.png")
  file.path(dir_name, base_name)
}

# Plot performance against the number of iterations for a single measure.
#
# Args:
#   which_measure: measure to plot; rows of `tb` with this `measure_name` are
#     used. "auc", "r2", "LogLoss" and "ppv" get a dedicated axis label, any
#     other value is used as the label verbatim.
#   tb: data frame with columns `niters`, `performance`, `delta_lo`,
#     `delta_hi`, `measure_name` and `performance_label`.
#   min_diff, max_diff: accepted for backward compatibility with existing
#     callers; they are not used when drawing the plot.
#
# Returns:
#   A ggplot object: one dashed line with points and a ribbon per
#   `performance_label`, plus a vertical marker and a text label at the
#   iteration count where each series peaks.
#
# Relies on the `a` aesthetics module (for `a$disc_palette`) being loaded in
# the enclosing script.
plot_performance <- function(which_measure, tb, min_diff, max_diff) {
  gg_tb <- filter(tb, measure_name == which_measure)

  # which.max() yields exactly one iteration count per series, even if several
  # iteration counts tie for the best performance (the first one wins).
  best_n_df <- gg_tb %>%
    group_by(performance_label) %>%
    summarize(best_niters = niters[which.max(performance)]) %>%
    ungroup()
  best_n <- setNames(best_n_df$best_niters, best_n_df$performance_label)

  color_key <- setNames(
    a$disc_palette[1:2],
    c("Yield of Testing", "Testing Decision")
  )

  y_lab <- switch(which_measure,
    auc = "AUC",
    r2 = "R2",
    LogLoss = "Log Loss",
    ppv = "Positive Predictive Value",
    which_measure
  )

  # NOTE(review): delta_lo / delta_hi are drawn as absolute ribbon bounds;
  # confirm upstream that they are not offsets around `performance`.
  gg <- ggplot(
    gg_tb,
    aes(
      x = niters,
      y = performance,
      ymin = delta_lo,
      ymax = delta_hi,
      color = performance_label,
      fill = performance_label
    )
  ) +
    labs(
      x = "Num iterations",
      y = y_lab,
      # caption = glue("95% confidence intervals from 1,000 bootstrap iterations."),
      color = glue("Predictive accuracy ({y_lab}) of regularized model")
    ) +
    scale_x_continuous(
      trans = "log10",
      breaks = log_step(gg_tb$niters[gg_tb$niters > 0]),
      minor_breaks = NULL
    ) +
    scale_color_manual(values = color_key, aesthetics = c("color", "fill")) +
    theme_bw() +
    geom_ribbon(
      alpha = 0.4,
      color = NA
    ) +
    geom_line(linetype = "dashed") +
    geom_point() +
    theme(
      legend.position = "bottom",
      text = element_text(family = "Optima", size = 40)
    ) +
    # "none" replaces the deprecated `fill = FALSE` way of hiding a legend.
    guides(fill = "none", col = guide_legend(nrow = 2))

  # Mark the best iteration count of each series. The "Test" label sits
  # slightly lower so the two labels do not overlap.
  lab_height_base <- max(gg_tb$performance) + 0.1
  for (grp in names(best_n)) {
    is_yield <- grepl("Yield", grp)
    outcome <- if (is_yield) "Yield" else "Test"
    lab_height <- if (is_yield) lab_height_base else lab_height_base - 0.05

    grp_k <- best_n[[grp]]
    grp_color <- unname(color_key[grp])
    lab <- glue("Maximum Performance: {outcome}\nk = {grp_k}")

    gg <- gg +
      geom_vline(xintercept = grp_k, color = grp_color) +
      annotate("label",
        x = grp_k, y = lab_height, label = lab,
        size = 14, lineheight = 0.25, fill = grp_color, alpha = 0.4,
        hjust = 0, family = "Optima"
      )
  }

  gg
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# Main performance table, read from the scratch directory.
performance_tb_raw <- readRDS(file.path(temp, "performance.rds"))

# The "dsmall" table is indexed by `depth`; replace that column with
# `niters = 1 - 1 / depth` so it can be stacked with the main table.
performance_dsmall_tb <- file.path(temp, "performance_dsmall.rds") %>%
  readRDS() %>%
  mutate(niters = 1 - (1 / depth)) %>%
  select(-depth)

# Training cohort for the "random" cohort definition.
cohort_fp <- file.path(paths$modeling$dir, "cohorts", "random")
ids <- cohort_fp %>%
  file.path("train_cohort_tested.rds") %>%
  readRDS()

# The test-cohort path is a glue template that is interpolated against the
# variables defined in this script -- presumably it references
# `overnight_lab`; confirm against lib/filepaths.yml.
overnight_lab <- ""
test_cohort <- readRDS(glue(paths$analysis$test_cohort))

# D1 performance (0 AUC and R2) ------------------------------------------------
# Four placeholder rows at niters = 0: one per (measure, target) pair, with
# performance, standard error and interval bounds all set to zero.
# NOTE(review): a chance-level classifier usually scores AUC 0.5 rather than 0;
# confirm that 0 is the intended baseline here.
performance_d1_tb <- tibble(
  niters = c(0, 0, 0, 0),
  target_name = c(
    "stent_or_cabg_010_day", "test_010_day",
    "stent_or_cabg_010_day", "test_010_day"
  ),
  measure_name = rep(c("r2", "auc"), each = 2),
  mean_performance = 0,
  std.error = 0,
  delta_lo = 0,
  delta_hi = 0
)

# Labeling ---------------------------------------------------------------------
message("Labeling performance...")
# Stack the three performance tables, attach a short name and a legend label
# to every row, rename `mean_performance` to `performance`, and keep only rows
# with at most 400 iterations.
performance_tb <- rbind(performance_tb_raw, performance_dsmall_tb) %>%
  rbind(performance_d1_tb) %>%
  mutate(
    performance_name = map2_chr(measure_name, target_name, name_performance),
    performance_label = map2_chr(measure_name, target_name, label_performance)
  ) %>%
  setnames("mean_performance", "performance") %>%
  filter(niters <= 400)

# Best Performance -------------------------------------------------------------
message("Getting best performance...")
# One row per (niters, measure) with the two targets side by side, plus the
# gap between them (stent/CABG target minus testing target).
diff_df <- performance_tb %>%
  select(niters, target_name, measure_name, performance) %>%
  spread(key = "target_name", value = "performance") %>%
  mutate(performance_diff = stent_or_cabg_010_day - test_010_day)

# For every measure, the smallest and largest gap and the iteration count at
# which each occurs. which.min() / which.max() return a single position even
# when several iteration counts tie, and skip rows where the gap is NA (i.e.
# one of the two targets is missing for that iteration count).
minmax_df <- diff_df %>%
  group_by(measure_name) %>%
  summarize(
    min_diff = min(performance_diff, na.rm = TRUE),
    min_k = niters[which.min(performance_diff)],
    max_diff = max(performance_diff, na.rm = TRUE),
    max_k = niters[which.max(performance_diff)]
  ) %>%
  ungroup()

# Despite their names, these hold the iteration counts (k) at which the
# smallest / largest gap occurs, keyed by measure name.
min_diffs <- setNames(minmax_df$min_k, minmax_df$measure_name)
max_diffs <- setNames(minmax_df$max_k, minmax_df$measure_name)

# Plots ------------------------------------------------------------------------
message("Plotting...")
# One plot per measure; each row pairs the output file name with its plot.
design <- tibble(which_measure = c("auc", "r2"))

plots <- transmute(
  design,
  filename = name_plot(which_measure, temp),
  plot = map(
    which_measure, plot_performance,
    tb = performance_tb, min_diff = min_diffs,
    max_diff = max_diffs
  )
)

# Save
message("Saving...")
# pwalk() is called for its side effect only: each row of `plots` supplies the
# `filename` and `plot` arguments of ggsave(). The result is deliberately not
# assigned, so no variable ends up shadowing base::save().
pwalk(plots, ggsave, width = 10, height = 7, units = "in")

message("Done.")
